package cn.designpattern;

import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;

/**
 * Entry point for sorting an {@code int} array in place with a fork/join based merge sort.
 */
public class PiggyMergeSort {
    /** Pool shared by every call to {@link #piggyMerge(int[])}; created with the no-argument constructor. */
    private static final ForkJoinPool forkJoinPool = new ForkJoinPool();

    /**
     * Sorts {@code arr} in place by submitting a {@code MergeTask} covering the whole array
     * to the shared pool and blocking until that task has completed.
     *
     * @param arr the array to sort; {@code null}, empty and single-element arrays are
     *            returned untouched
     */
    public static void piggyMerge(int[] arr) {
        // Nothing to do when there are fewer than two elements.
        if (arr == null || arr.length < 2) {
            return;
        }
        int lastIndex = arr.length - 1;
        MergeTask rootTask = new MergeTask(0, lastIndex, arr);
        // The task's return value is not needed here: the array is modified in place.
        forkJoinPool.submit(rootTask).join();
    }
}

/**
 * Fork/join task that sorts the inclusive index range {@code [left, right]} of an
 * {@code int} array in place, in ascending order, using merge sort.
 *
 * <p>A range for which {@code right - left} is at least the configured threshold is split
 * in half, the two halves are sorted by subtasks, and the sorted halves are merged.
 * Smaller ranges are sorted sequentially by {@link #mergeSort(int[], int, int)}.
 * The task's result is the same array instance that was passed to the constructor.
 */
class MergeTask extends RecursiveTask<int[]> {
    private final int left;
    private final int right;
    private final int[] array;
    /** Split threshold: a range is divided while {@code right - left >= number}. Never below 1. */
    private int number = 9;

    /**
     * Sets the split threshold used by this task and by every subtask it creates.
     *
     * @param number the smallest value of {@code right - left} that is still split into
     *               subtasks; values below 1 are treated as 1, because a smaller threshold
     *               would make single-element ranges split without end
     */
    public void setNumber(int number) {
        this.number = Math.max(1, number);
    }

    /**
     * Creates a task that sorts {@code array[left..right]} (both bounds inclusive).
     *
     * @param left  index of the first element of the range
     * @param right index of the last element of the range
     * @param array the array to sort in place
     */
    public MergeTask(int left, int right, int[] array) {
        super();
        this.left = left;
        this.right = right;
        this.array = array;
    }

    /**
     * Sorts this task's range, in parallel halves when the range is large enough.
     *
     * @return the array given to the constructor, with {@code [left, right]} sorted
     */
    @Override
    protected int[] compute() {
        if (right - left >= number) {
            // Written as left + (right - left) / 2 so the sum of two large indices cannot overflow.
            int mid = left + (right - left) / 2;
            MergeTask lowerHalf = new MergeTask(left, mid, array);
            MergeTask upperHalf = new MergeTask(mid + 1, right, array);
            // Subtasks inherit the configured threshold instead of silently reverting to the default.
            lowerHalf.setNumber(number);
            upperHalf.setNumber(number);
            // invokeAll forks the subtasks and returns once both are done (or rethrows a failure).
            invokeAll(lowerHalf, upperHalf);
            merge(array, left, mid, right);
        } else {
            mergeSort(array, left, right);
        }
        return array;
    }

    /**
     * Sequentially merge-sorts {@code arr[left..right]} (both bounds inclusive) in place.
     * Does nothing when {@code left >= right}.
     *
     * @param arr   the array to sort
     * @param left  index of the first element of the range
     * @param right index of the last element of the range
     */
    public void mergeSort(int[] arr, int left, int right) {
        if (left < right) {
            // Overflow-safe midpoint, see compute().
            int mid = left + (right - left) / 2;
            mergeSort(arr, left, mid);
            mergeSort(arr, mid + 1, right);
            merge(arr, left, mid, right);
        }
    }

    /**
     * Merges the two adjacent sorted runs {@code arr[left..mid]} and {@code arr[mid+1..right]}
     * into a single sorted run occupying {@code arr[left..right]}.
     */
    private void merge(int[] arr, int left, int mid, int right) {
        int[] temp = new int[right - left + 1];
        int i = left;
        int j = mid + 1;
        int k = 0;
        while (i <= mid && j <= right) {
            // On ties the element from the left run is taken first.
            if (arr[i] <= arr[j]) {
                temp[k++] = arr[i++];
            } else {
                temp[k++] = arr[j++];
            }
        }
        // At most one run still has elements; a zero-length copy is a no-op for the other.
        int leftRemaining = mid - i + 1;
        System.arraycopy(arr, i, temp, k, leftRemaining);
        System.arraycopy(arr, j, temp, k + leftRemaining, right - j + 1);
        // Copy the merged run from temp back into the original array.
        System.arraycopy(temp, 0, arr, left, temp.length);
    }
}
